# Kernel: sgemm_nn_128x128_vec

// Constant mapping: symbolic names used by the instructions below.
//   addr_zero      - offset used as [RZ + addr_zero] for the zero store/loads
//                    in the prologue and for the zero-fill loads in REMAINDER.
//                    4x<N> evaluates to 4*N here (4x<128*8> is used where the
//                    accompanying comment says "+ 4096").
//   gridDimA/B     - not referenced in this file; presumably used by the
//                    included common code (TODO confirm).
//   param_X[0]/[1] - two 32-bit words per pointer; the setup code adds [0]
//                    with carry-out and [1] with carry-in to form trackA/B.
//   param_ldb8     - shifted right by 5 to obtain ldb, and added directly to
//                    the byte address trackB when advancing (presumably
//                    8 rows * ldb * 4 bytes - TODO confirm against the host).
<CONSTANT_MAPPING>
  addr_zero  : 4x<128*8*4>

  gridDimA : c[0x0][0x14]
  gridDimB : c[0x0][0x18]

  param_A[0]  : c[0x0][0x140]
  param_A[1]  : c[0x0][0x144]
  param_B[0]  : c[0x0][0x148]
  param_B[1]  : c[0x0][0x14c]
  param_C[0]  : c[0x0][0x150]
  param_C[1]  : c[0x0][0x154]
  param_alpha : c[0x0][0x158]
  param_beta  : c[0x0][0x15c]
  param_lda   : c[0x0][0x160]
  param_ldb8  : c[0x0][0x164]
  param_ldc   : c[0x0][0x168]
  param_m     : c[0x0][0x16c]
  param_n     : c[0x0][0x170]
  param_k     : c[0x0][0x174]
</CONSTANT_MAPPING>

// Register mapping: binds symbolic names to register numbers.
// Two forms appear below: "range : names" and "range ~ names".
// NOTE(review): presumably ':' assigns the names to the listed registers in
// order while '~' lets the assembler pick registers from the range - confirm
// against the assembler documentation.
// Several ranges overlap on purpose (e.g. 64-91 setup temporaries vs. the
// 64-95 j0/j1 operands vs. the 64-75 names near the end); presumably the
// overlapping names are live in different phases of the kernel - TODO confirm.
<REGISTER_MAPPING>
  // Temporaries used by the address setup code before the main loop.
  64-91 ~ blkA, blkB, blkZ, tidAX, tidBX, lda, ldb, ldaz, ldbz, tid1, tid7, txa, txb, tid31, tid128, tidBY, ta, tb, tmp_shl
  // Scratch pair holding the two 32-bit words of a pointer parameter.
  92-93 ~ tmp_param0, tmp_param1

  // Linear view of registers 0-63, used to zero them with LDS.128.
  0-63 : czero<00-63>

  // Interleaved view of the same registers 0-63 as cx<0-7>y<0-7>.
  // The non-sequential numbering is chosen to
  // avoid ffma single instruction bank conflict

   1,  4, 17, 20, 33, 36, 49, 52 : cx<0-7>y0
   5,  0, 21, 16, 37, 32, 53, 48 : cx<0-7>y1
   3,  6, 19, 22, 35, 38, 51, 54 : cx<0-7>y2
   7,  2, 23, 18, 39, 34, 55, 50 : cx<0-7>y3
   9, 12, 25, 28, 41, 44, 57, 60 : cx<0-7>y4
  13,  8, 29, 24, 45, 40, 61, 56 : cx<0-7>y5
  11, 14, 27, 30, 43, 46, 59, 62 : cx<0-7>y6
  15, 10, 31, 26, 47, 42, 63, 58 : cx<0-7>y7

  // Two sets (j0, j1) of A/B operand registers; they reuse the 64-95 range
  // that the setup temporaries above were mapped to.
  64-67 : j0Ay<0-3>
  68-71 : j0Bx<0-3>
  72-75 : j0Ay<4-7>
  76-79 : j0Bx<4-7>
  80-83 : j1Ay<0-3>
  84-87 : j1Bx<0-3>
  88-91 : j1Ay<4-7>
  92-95 : j1Bx<4-7>

  // Destination registers of the 128-bit loads of A and B.
  96-103 : loadA<0-3>, loadB<0-3>

  // Address register pairs (word 0 / word 1) for A and B.
  104-107 : trackA<0-1>, trackB<0-1>

  // Values that are still read in REMAINDER and in the inserted loop
  // instructions; mapped outside the 64-95 range reused by j0/j1.
  108-112 ~ writeAs, writeBs, k, k_and, tidAY
  // Fixed single-register placements chosen
  // to avoid lds bank conflict with ffma
  117 ~ readAs
  116 ~ readBs
  115 ~ tid

  // Names below are not referenced in this file; presumably they are used
  // by the included common code for the C output phase (TODO confirm).
  64-75 ~ ldc, ci, tid_31, tid_96, tid_128, blockA, blockB, blockZ
  64-75 : c<0-7>, d3, d2, d1, d0
  76-85 : C00y<0-1>, C04y<0-1>, C08y<0-1>, C12y<0-1>
  86-101 ~ ldc1, ldc4, ldc60, ldcz, writeCs, readCs, cx<00|64>, cy<00|04|08|12>, alpha, beta, flags

</REGISTER_MAPPING>

// Prologue: read special registers, set up strides, zero the accumulators.

// Copy special registers into general registers:
// thread index X -> tid, CTA index Y -> blkA, Z -> blkB, X -> blkZ.
-:-:-:-:00 S2R tid,  SR_TID.X;
-:-:-:-:00 S2R blkA, SR_CTAID.Y;
-:-:-:-:00 S2R blkB, SR_CTAID.Z;
-:-:-:-:00 S2R blkZ, SR_CTAID.X;//blkZ=1

// k starts at param_k. The three *z strides are set to zero, so the later
// "stride_z * blkZ" terms contribute nothing to the addresses.
-:-:-:-:00 MOV k,   param_k;
-:-:-:-:00 MOV ldaz, RZ;
-:-:-:-:00 MOV ldbz, RZ;
-:-:-:-:00 MOV ldcz, RZ;
-:-:-:-:00 MOV lda, param_lda;
// ldb = param_ldb8 >> 5, i.e. param_ldb8 / 32; the result is used as an
// element stride (it is scaled by 4 later), not as a byte stride.
-:-:-:-:00 MOV ldb, param_ldb8;
-:-:-:-:00 SHR.U32 ldb, ldb, 5;//ldb is not byte
// Store RZ at addr_zero, then load it back into czero00, czero04, ...,
// czero60 (16 loads of 128 bits), which covers registers 0-63 and
// presumably leaves every accumulator at zero.
-:-:-:-:00 STS.128 [RZ + addr_zero], RZ;
<CODE>
  join('', map sprintf("-:-:-:-:00 LDS.128 czero%02d, [RZ + addr_zero];\n", $_ * 4), 0..15);
</CODE>

// Global address of this thread's first A load, plus the A bounds predicate.

// tidAY  = (tid & 1) << 2   (0 or 4)
-:-:-:-:00 LOP.AND tid1,  tid,  1;
-:-:-:-:00 SHL     tidAY, tid1, 2;

// tidAX = tid >> 1
-:-:-:-:00 SHR.U32 tidAX, tid, 1;

// trackA += 4 * ((blkA*128 + tidAX) * lda + tidAY)
// txa = (blkA << 7) + tidAX
-:-:-:-:00 ISCADD txa, blkA, tidAX, 7;
// ta = lda * txa + tidAY
-:-:-:-:00 IMAD   ta, lda, txa, tidAY;
// ta += ldaz * blkZ; ldaz was set to RZ above, so ta is unchanged.
-:-:-:-:00 IMAD   ta, ldaz, blkZ, ta;
// TODO(keren): 0x2?
// Scale the element offset by 4 (shift left by 2) and add it to the pointer:
// word 0 is added with carry-out (.CC), word 1 absorbs the carry (.X).
-:-:-:-:00 MOV tmp_param0, param_A[0];
-:-:-:-:00 MOV tmp_param1, param_A[1];
-:-:-:-:00 SHL tmp_shl, ta, 0x2;
-:-:-:-:00 IADD trackA0.CC, tmp_shl, tmp_param0;
-:-:-:-:00 IADD.X trackA1, RZ, tmp_param1;

// P5 = (blkA*128 + tidAX) < param_m
-:-:-:-:00 ISETP.LT.AND P5, PT, txa, param_m, PT;

// Global address of this thread's first B load, plus the B bounds predicate.

// tidBX = (tid & 31) << 2
// tidBY = (tid >> 5) & 7
-:-:-:-:00 LOP.AND tid31, tid, 31;
-:-:-:-:00 SHL     tidBX, tid31, 2;
-:-:-:-:00 BFE.U32 tidBY, tid, 0x305; // 3 bits at position 5

// trackB += (blkB*128 + ldb*tidBY + tidBX) * 4
// txb = (blkB << 7) + tidBX
-:-:-:-:00 ISCADD txb, blkB, tidBX, 7;
// tb = ldb * tidBY + txb
-:-:-:-:00 IMAD tb, ldb, tidBY, txb;
// tb += ldbz * blkZ; ldbz was set to RZ above, so tb is unchanged.
-:-:-:-:00 IMAD tb, ldbz, blkZ, tb;
// TODO(keren): 0x2?
// Scale the element offset by 4 (shift left by 2) and add it to the pointer:
// word 0 is added with carry-out (.CC), word 1 absorbs the carry (.X).
-:-:-:-:00 MOV tmp_param0, param_B[0];
-:-:-:-:00 MOV tmp_param1, param_B[1];
-:-:-:-:00 SHL tmp_shl, tb, 0x2;
-:-:-:-:00 IADD trackB0.CC, tmp_shl, tmp_param0;
-:-:-:-:00 IADD.X trackB1, RZ, tmp_param1;

// TODO(keren): blkB * 128 + tidBX < param_n
// P6 = txb < param_n. Only the first of the four elements loaded per thread
// is compared; the REMAINDER comments state n must be a multiple of 4.
-:-:-:-:00 ISETP.LT.AND P6, PT, txb, param_n, PT;

// Shared-memory offsets used for storing (write*) and loading (read*) tiles.
// Initial bases: readAs 0, readBs 4x<128*8>, writeAs 4x<128*8*2>,
// writeBs 4x<128*8*3>.

// writeAs = 4 * (128 * tidAY + tidAX) + 4x<128*8*2>
-:-:-:-:00 ISCADD writeAs, tidAY, tidAX, 7;
-:-:-:-:00 ISCADD writeAs, writeAs, 4x<128*8*2>, 2;

// writeBs = (128*tidBY + tidBX) * 4 + 4x<128*8*3>
-:-:-:-:00 ISCADD writeBs, tidBY, tidBX, 7;
-:-:-:-:00 ISCADD writeBs, writeBs, 4x<128*8*3>, 2;

// readAs  = (((tid & 0x70) >> 3) | (tid & 1)) << 4
-:-:-:-:00 LOP.AND readAs, tid,    0x70;
-:-:-:-:00 SHR.U32 readAs, readAs, 3;
-:-:-:-:00 LOP.OR  readAs, readAs, tid1;
-:-:-:-:00 SHL     readAs, readAs, 4;

// (keren): A allocate 128 * 8 elements
// readBs = (((tid128 >> 4) | ((tid >> 1) & 7)) << 4) + 4096;
// where tid128 = tid & 128 and 4096 = 4x<128*8>
-:-:-:-:00 LOP.AND tid128, tid,    128;
-:-:-:-:00 BFE.U32 tid7,   tid,    0x301; // 3 bits at position 1
-:-:-:-:00 SHR.U32 readBs, tid128, 4;
-:-:-:-:00 LOP.OR  readBs, readBs, tid7;
-:-:-:-:00 ISCADD  readBs, readBs, 4x<128*8>, 4;

  REMAINDER:
// Tile load executed on fall-through from the setup code, and again whenever
// the "@P1 BRA.U REMAINDER" from the inserted loop instructions is taken.
// The Perl block returns the fragment as a q{} string (no interpolation).
// The fragment:
//   - loads 128 bits of B under P6 and 128 bits of A under P5, where P5 is
//     first narrowed to "tidAY < k";
//   - sets P1 = (k & 7) != 0 && k > 8;
//   - where a predicate is false, fills the load registers from addr_zero;
//   - stores B with one STS.128 and A with four STS at a 4x<128> stride;
//   - advances trackB by param_ldb8 and trackA by 4x<8>, with carry.
// NOTE(review): the B load is predicated on P6 only, and P6 is derived from
// txb < param_n; unlike A there is no "tidBY < k" test visible here. Confirm
// that reading B rows at or beyond k is safe when k is not a multiple of 8.

<CODE>
    return q{
      // k must be a multiple of 4
      // n must be a multiple of 4
      //       -
      //       -
      //       -
      //       -
      // tidBY --------------- trackB ---- loadB0
      //            blkB    tidBX
      -:-:-:-:00 @P6 LD.E.CI.128 loadB0, [trackB];

      //       -
      //       -
      // blkA  -
      //       -
      //       -
      // tidAX ---- trackA ---- loadA0 -------- loadA4

      // load if tidAY < k (tidAY == 0 if mod 4 not mod 8)
      -:-:-:-:00 ISETP.LT.AND P5, PT, tidAY, k, P5;
      -:-:-:-:00 @P5 LD.E.CI.128 loadA0, [trackA];

      // bDoRemainder = k & 7 && k > 8
      -:-:-:-:00 LOP.AND k_and, k, 7;
      -:-:-:-:00 ISETP.EQ.AND P1, PT, k_and, RZ, PT;
      -:-:-:-:00 ISETP.GT.AND P1, PT, k, 8, !P1;

      -:-:-:-:00 @!P6 LDS.128 loadB0, [RZ + addr_zero];
      -:-:-:-:00 @!P5 LDS.128 loadA0, [RZ + addr_zero];

      // ----------------------
      // ---------------------- 
      // ---------------------- tidBY
      // ----- writeBS ---- loadB0
      // tidBX

      -:-:-:-:00 STS.128 [writeBs], loadB0;

      // ------------------
      // ------------------ tidAY 0, 4
      // ------------------
      // ------ writeAS - loadA0
      // ---------------- loadA1
      // ---------------- loadA2
      // ---------------- loadA3
      // tidAX
      -:-:-:-:00 STS [writeAs + 4x<0*128>], loadA0;
      -:-:-:-:00 STS [writeAs + 4x<1*128>], loadA1;
      -:-:-:-:00 STS [writeAs + 4x<2*128>], loadA2;
      -:-:-:-:00 STS [writeAs + 4x<3*128>], loadA3;

      -:-:-:-:00 IADD   trackB0.CC, trackB0, param_ldb8;
      -:-:-:-:00 IADD.X trackB1, trackB1, RZ;

      -:-:-:-:00 IADD   trackA0.CC, trackA0, 4x<8>;
      -:-:-:-:00 IADD.X trackA1, trackA1, RZ;
    };
</CODE>

// TODO(keren): double buffer?
// Toggle the 4x<128*8*2> bit of all four shared-memory offsets around a
// barrier. With the initial bases (readAs 0, readBs 4x<128*8>,
// writeAs 4x<128*8*2>, writeBs 4x<128*8*3>) this moves the read offsets onto
// the region that was just written and the write offsets onto the other one.
// The read offsets are flipped before BAR.SYNC and the write offsets after it.
-:-:-:-:00 LOP.XOR readAs, readAs, 4x<128*8*2>;
-:-:-:-:00 LOP.XOR readBs, readBs, 4x<128*8*2>;
-:-:-:-:00 BAR.SYNC 0;
-:-:-:-:00 LOP.XOR writeAs, writeAs, 4x<128*8*2>;
-:-:-:-:00 LOP.XOR writeBs, writeBs, 4x<128*8*2>;

// instruction align

<CODE>
    # Table of extra instructions keyed by slot name "j<N>c<M>".
    # NOTE(review): presumably the included sgemm_common_128x128.sass reads
    # %insert and places each string at the matching position of its
    # unrolled loop body - confirm there; the LOOP label targeted below is
    # not defined in this file.
    #
    # The values are double-quoted strings: $k_end is interpolated, and "\@"
    # keeps a literal '@' for the instruction predicates.

    # Threshold compared against k by the P0/P2/P3 tests below.
    my $k_end = 16;
    our %insert =
    (
        # P0 must be the topmost (original note; exact meaning not evident
        # from this file - TODO confirm)

        # j0: P2/P3 = (k >= $k_end) combined with the A/B bounds predicates
        # P5/P6, then the predicated loads of the next A and B data.
        j0c47 => "-:-:-:-:00 ISETP.GE.AND P2, PT, k, $k_end, P5;\n",
        j0c53 => "-:-:-:-:00 ISETP.GE.AND P3, PT, k, $k_end, P6;\n",
        j0c62 => "-:G:D:-:00 \@P2 LDG.E.CI.128 loadA0, [trackA];\n",
        j0c63 => "-:G:D:-:00 \@P3 LDG.E.CI.128 loadB0, [trackB];\n",

        # j1/j2: advance the track addresses; word 0 with carry-out in j1,
        # word 1 with carry-in in j2.
        j1c47 => "-:-:-:-:00 \@P2 IADD   trackA0.CC, trackA0, 4x<8>;\n",
        j1c53 => "-:-:-:-:00 \@P3 IADD   trackB0.CC, trackB0, param_ldb8;\n",

        j2c47 => "-:-:-:-:00 \@P2 IADD.X trackA1,    trackA1, RZ;\n",
        j2c53 => "-:-:-:-:00 \@P3 IADD.X trackB1,    trackB1, RZ;\n",

        # j3: P0 = (k >= $k_end), evaluated before k is decremented by 8.
        j3c47 => "-:-:-:-:00 ISETP.GE.AND P0, PT, k, $k_end, PT;\n",
        j3c53 => "-:-:-:-:00 IADD32I k, k, -8;\n",

        # j5: TEXDEPBAR 0x1, then store the loaded A values under P0.
        j5c47 => "T:-:D:S:00 TEXDEPBAR 0x1;\n", 
        j5c53 => "-:-:D:S:00 \@P0 STS [writeAs + 4x<0*128>], loadA0;\n",
        j5c61 => "-:-:D:S:00 \@P0 STS [writeAs + 4x<1*128>], loadA1;\n",
        j5c62 => "-:-:D:S:00 \@P0 STS [writeAs + 4x<2*128>], loadA2;\n",
        j5c63 => "-:-:D:S:00 \@P0 STS [writeAs + 4x<3*128>], loadA3;\n",

        # j6: TEXDEPBAR 0x0, then store the loaded B values under P0.
        j6c47 => "T:-:D:S:00 TEXDEPBAR 0x0;\n",
        j6c53 => "-:-:D:S:00 \@P0 STS.128 [writeBs], loadB0;\n",

        # j6/j7: under P0, toggle the read offsets, synchronize, then toggle
        # the write offsets (same 4x<128*8*2> XOR as before the loop).
        j6c61 => "-:-:-:-:00 \@P0 LOP.XOR readBs, readBs, 4x<128*8*2>;\n",
        j6c62 => "-:-:-:-:00 \@P0 LOP.XOR readAs, readAs, 4x<128*8*2>;\n",
        j6c63 => "T:-:D:S:00 BAR.SYNC 0x0;\n",

        j7c47 => "-:-:-:-:00 \@P0 LOP.XOR writeAs, writeAs, 4x<128*8*2>;\n",
        j7c53 => "-:-:-:-:00 \@P0 LOP.XOR writeBs, writeBs, 4x<128*8*2>;\n",
        # Two instructions in one slot: branch to LOOP while P0 holds,
        # otherwise branch to REMAINDER when P1 (set in REMAINDER) holds.
        j7c63 => "-:-:-:-:00 \@P0 BRA.U LOOP;\n".
                 "-:-:-:-:00 \@P1 BRA.U REMAINDER;\n",
    );
    # Return nothing: this block only populates %insert.
    return;
</CODE>

<INCLUDE file="sgemm_common_128x128.sass"/>

